from meas_sequences.eef_sequence_generator import EEFSequenceGenerator
from meas_sequences.eef_sequence_generator import IntegratorMode
from meas_sequences.eef_sequence_generator import AnalogControlMode
from meas_sequences.eef_sequence_generator import AnalogRange
from meas_sequences.eef_sequence_generator import SetRegisterCondition
from meas_sequences.eef_sequence_generator import TriggerInput
from meas_sequences.eef_sequence_generator import TriggerOutput
from meas_sequences.eef_sequence_generator import SignalOutput
from meas_sequences.eef_sequence_generator import MeasurementChannel


FI_RESULT_BUFFER_SIZE = 18      # Number of result-buffer entries (addresses 0..17) written by one FI measurement
FI_MEASUREMENT_WINDOW = 50e-6   # Pulse-counting window per flash, in seconds (stepped in 1 us increments)


class FIResultData:
    """
    Parsed view of the raw result buffer of an FI measurement.

    ``results`` must be an indexable sequence with exactly
    FI_RESULT_BUFFER_SIZE entries. Pulse counts occupy two consecutive
    entries (low word first, high word second) and are merged into a single
    integer as ``(high << 32) + low``.

    The valid PMT pulse counts are the offset-0 slots (4..7) and the valid PMT
    analog values are the offset-10 slots (10..13) written by fi_core_sequence;
    the complementary slots are kept in attributes prefixed with ``_invalid_``.

    Raises ValueError if ``results`` does not have the expected length.
    """

    def __init__(self, results):
        if FI_RESULT_BUFFER_SIZE != len(results):
            raise ValueError(f"Invalid number of results: {len(results)}. Expected {FI_RESULT_BUFFER_SIZE}.")

        def dword(low_index):
            # Merge results[low_index] (low word) and results[low_index + 1]
            # (high word) into one value.
            return (results[low_index + 1] << 32) + results[low_index]

        # Valid data: reference photodiode, PMT pulse counts, PMT analog values.
        self.ref_analog_low, self.ref_analog_high = results[8], results[9]
        self.pmt1_counts = dword(4)
        self.pmt2_counts = dword(6)
        self.pmt1_analog_low, self.pmt1_analog_high = results[10], results[11]
        self.pmt2_analog_low, self.pmt2_analog_high = results[12], results[13]

        # Invalid data: the complementary PMT slots.
        self._invalid_pmt1_counts = dword(14)
        self._invalid_pmt2_counts = dword(16)
        self._invalid_pmt1_analog_low, self._invalid_pmt1_analog_high = results[0], results[1]
        self._invalid_pmt2_analog_low, self._invalid_pmt2_analog_high = results[2], results[3]


def fi_single_mode_sequence(number_of_flashes, frequency_of_flashes, charging_time_us, analog_limit):
    """
    Build a complete, stand-alone FI measurement sequence.

    Data register 0 is zeroed first; fi_core_sequence adds this register to all
    of its result addresses, so results start at buffer address 0. The result
    buffer range used by the measurement is cleared, the core flash loop is
    appended, and the sequence is terminated with ``stop(0)``.

    :return: The generated sequence (``EEFSequenceGenerator.sequence``).
    :raises ValueError: Propagated from fi_core_sequence for invalid timing.
    """
    generator = EEFSequenceGenerator()

    # Result base offset D(0) = 0 and an empty result buffer.
    generator.set_data_register(register=0, value=0)
    generator.clear_result_buffer_range(address=0, size=FI_RESULT_BUFFER_SIZE)

    generator.sequence += fi_core_sequence(
        number_of_flashes=number_of_flashes,
        frequency_of_flashes=frequency_of_flashes,
        charging_time_us=charging_time_us,
        analog_limit=analog_limit,
    )

    generator.stop(0)
    return generator.sequence


def fi_core_sequence(number_of_flashes, frequency_of_flashes, charging_time_us, analog_limit, step_counter_delay=None):
    """
    Build the flash loop of an FI measurement (no result-buffer setup, no stop).

    Every flash iteration does the following:

    * charges the flash lamp;
    * measures the high and low range integrator offsets;
    * opens the PMT input gates and triggers the flash;
    * accumulates the PMT1/PMT2 pulse counters over a FI_MEASUREMENT_WINDOW
      long window, in 1 us steps;
    * samples the analog integrators in fixed range;
    * writes all results into the result buffer with ``add=True``
      (presumably accumulating over all flashes).

    When ``number_of_flashes > 1`` the timer waits of one iteration add up to
    one flash period of ``1 / frequency_of_flashes``, rounded to whole
    microseconds.

    Result buffer addresses are all shifted by the value of data register 0:

    * PD1REF analog low/high: 8 / 9
    * PMT1 analog low/high: A(1) + 0 / A(1) + 1, pulse count (dword): A(1) + 4
    * PMT2 analog low/high: A(2) + 2 / A(2) + 3, pulse count (dword): A(2) + 6

    If ``analog_limit <= 0``, A(1) and A(2) are fixed to 10. Otherwise they
    start at 0 and each is changed by a relative address of 10 when that
    channel's analog signal is in high range or greater than ``analog_limit``.

    :param number_of_flashes: Number of flash loop iterations.
    :param frequency_of_flashes: Flash repetition rate in Hz; must be > 0.
    :param charging_time_us: Flash lamp charging time in microseconds.
    :param analog_limit: Analog threshold. Only used when > 0, in which case it
        is stored in data register 10.
    :param step_counter_delay: Optional value for step_counter_wait_and_restart,
        inserted right before the input gates are opened.
    :return: The generated sequence (``EEFSequenceGenerator.sequence``).
    :raises ValueError: If ``frequency_of_flashes`` is not positive, or if the
        timing steps of one flash do not fit into one flash period.
    """
    one_us_delay = 1
    conversion_delay_us = 12
    switch_delay_us = 0.25
    reset_switch_delay_us = 1
    input_gate_delay = 1
    fixed_range_us = 20
    window_us = round(FI_MEASUREMENT_WINDOW * 1e6)

    # Data register holding the analog threshold for the result slot selection.
    analog_limit_register = 10

    # Guard the division below: a non-positive rate has no valid flash period.
    if frequency_of_flashes <= 0:
        raise ValueError(f"Invalid FI measurement parameters. Flash frequency must be positive, got {frequency_of_flashes}.")

    total_flash_time_us = round(1000000 / frequency_of_flashes)
    # Time left in the flash period after all fixed timer steps of one iteration.
    # Must stay in sync with the timer_wait_and_restart_us() calls in the flash
    # loop below (the window loop waits one_us_delay per step, i.e. window_us in total).
    remaining_time_us = (total_flash_time_us
                         - charging_time_us
                         - window_us
                         - 2 * conversion_delay_us
                         - fixed_range_us
                         - switch_delay_us
                         - reset_switch_delay_us
                         - input_gate_delay
                         - one_us_delay)

    if remaining_time_us < 0:
        required_time_us = total_flash_time_us - remaining_time_us
        raise ValueError(
            f"Invalid FI measurement parameters. Flash frequency can not be reached "
            f"(flash period {total_flash_time_us} us, required at least {required_time_us} us)."
        )

    seq_gen = EEFSequenceGenerator()

    if analog_limit > 0:
        seq_gen.set_data_register(analog_limit_register, analog_limit)
    seq_gen.reset_signal_outputs(SignalOutput.InputGatePMT1 | SignalOutput.InputGatePMT2 | SignalOutput.HVGatePMT1 | SignalOutput.HVGatePMT2)
    seq_gen.set_signal_outputs(SignalOutput.HVGatePMT1 | SignalOutput.HVGatePMT2)

    ### Flash Loop ###
    seq_gen.loop(number_of_flashes)
    seq_gen.set_analog_control(pd1ref=AnalogControlMode.FullOffsetReset, pmt1=AnalogControlMode.FullOffsetReset, pmt2=AnalogControlMode.FullOffsetReset)
    # Start charging the flash lamp and resetting the analog integrators
    seq_gen.timer_wait_and_restart_us(charging_time_us)
    seq_gen.set_signal_outputs(SignalOutput.Flash)
    seq_gen.set_integrator_mode(pd1ref=IntegratorMode.FullReset, pmt1=IntegratorMode.FullReset, pmt2=IntegratorMode.FullReset)
    seq_gen.reset_signal_outputs(SignalOutput.InputGatePMT1 | SignalOutput.InputGatePMT2)
    # Measure the high range offset
    seq_gen.timer_wait_and_restart_us(conversion_delay_us)
    seq_gen.set_trigger_outputs(TriggerOutput.SamplePD1 | TriggerOutput.SamplePMT1 | TriggerOutput.SamplePMT2)
    # Start the integrator in low range (250 ns before the low range offset is measured)
    seq_gen.timer_wait_and_restart_us(switch_delay_us)
    seq_gen.set_integrator_mode(pd1ref=IntegratorMode.IntegrateLowResetHigh, pmt1=IntegratorMode.IntegrateLowResetHigh, pmt2=IntegratorMode.IntegrateLowResetHigh)
    seq_gen.set_analog_control(pd1ref=AnalogControlMode.LoadOffset, pmt1=AnalogControlMode.LoadOffset, pmt2=AnalogControlMode.LoadOffset)
    # Measure the low range offset, then switch to auto range
    seq_gen.timer_wait_and_restart_us(reset_switch_delay_us)
    seq_gen.set_trigger_outputs(TriggerOutput.SamplePD1 | TriggerOutput.SamplePMT1 | TriggerOutput.SamplePMT2)
    seq_gen.set_integrator_mode(pd1ref=IntegratorMode.IntegrateAutoRange, pmt1=IntegratorMode.IntegrateAutoRange, pmt2=IntegratorMode.IntegrateAutoRange)

    if step_counter_delay is not None:
        seq_gen.step_counter_wait_and_restart(step_counter_delay)

    # Open the PMT input gates, then reset both pulse counters 1 us later
    seq_gen.timer_wait_and_restart_us(input_gate_delay)
    seq_gen.set_signal_outputs(SignalOutput.InputGatePMT1 | SignalOutput.InputGatePMT2)
    seq_gen.timer_wait_and_restart_us(one_us_delay)
    seq_gen.pulse_counter_control_reset(MeasurementChannel.PMT1)
    seq_gen.pulse_counter_control_reset(MeasurementChannel.PMT2)

    # Trigger the flash
    seq_gen.reset_signal_outputs(SignalOutput.Flash)

    ### Window Loop ###
    # Accumulate both pulse counters once per microsecond for window_us steps.
    seq_gen.loop(window_us)
    seq_gen.timer_wait_and_restart_us(one_us_delay)
    seq_gen.pulse_counter_control_add(MeasurementChannel.PMT1, use_correction=True)
    seq_gen.pulse_counter_control_add(MeasurementChannel.PMT2, use_correction=True)
    seq_gen.loop_end()
    ### Window Loop End ###

    seq_gen.reset_signal_outputs(SignalOutput.InputGatePMT1 | SignalOutput.InputGatePMT2)
    # Switch to fixed range and load the low range offset, measured before
    seq_gen.timer_wait_and_restart_us(fixed_range_us)
    seq_gen.set_integrator_mode(pd1ref=IntegratorMode.IntegrateFixedRange, pmt1=IntegratorMode.IntegrateFixedRange, pmt2=IntegratorMode.IntegrateFixedRange)
    seq_gen.set_analog_control(pd1ref=AnalogControlMode.LoadOffset, pmt1=AnalogControlMode.LoadOffset, pmt2=AnalogControlMode.LoadOffset)

    seq_gen.timer_wait_and_restart_us(conversion_delay_us)
    seq_gen.set_trigger_outputs(TriggerOutput.SamplePD1 | TriggerOutput.SamplePMT1 | TriggerOutput.SamplePMT2)

    if number_of_flashes > 1:
        # Pad the iteration to the full flash period before the next flash.
        seq_gen.timer_wait_and_restart_us(remaining_time_us)
    else:
        # A single flash needs no padding to a following flash period.
        seq_gen.timer_wait()

    # Select the result slots: A(0) for PD1REF, A(1)/A(2) for PMT1/PMT2.
    seq_gen.set_address_register_from_data_register(register=0, data_register=0)    # A(0) = D(0)
    if analog_limit > 0:
        seq_gen.set_address_register(register=1, address=0)     # A(1) = 0
        seq_gen.set_address_register(register=2, address=0)     # A(2) = 0
        # NOTE(review): if relative=True adds to the current register value, a
        # channel that is both in high range and above the limit ends up at
        # offset 20, outside the FI_RESULT_BUFFER_SIZE buffer. Confirm the
        # firmware semantics of these conditional address instructions.
        seq_gen.set_address_register_if_analog_is_high_range(register=1, address=10, relative=True, channel=MeasurementChannel.PMT1)
        seq_gen.set_address_register_if_analog_is_greater(register=1, address=10, relative=True, channel=MeasurementChannel.PMT1, data_register=analog_limit_register)
        seq_gen.set_address_register_if_analog_is_high_range(register=2, address=10, relative=True, channel=MeasurementChannel.PMT2)
        seq_gen.set_address_register_if_analog_is_greater(register=2, address=10, relative=True, channel=MeasurementChannel.PMT2, data_register=analog_limit_register)
    else:
        seq_gen.set_address_register(register=1, address=10)    # A(1) = 10
        seq_gen.set_address_register(register=2, address=10)    # A(2) = 10

    seq_gen.increment_address_register_from_data_register(register=1, data_register=0)  # A(1) = A(1) + D(0)
    seq_gen.increment_address_register_from_data_register(register=2, data_register=0)  # A(2) = A(2) + D(0)

    seq_gen.get_analog_low_result(MeasurementChannel.PMT1, address=0, relative_to_register=1, add=True)
    seq_gen.get_analog_high_result(MeasurementChannel.PMT1, address=1, relative_to_register=1, add=True)

    seq_gen.get_analog_low_result(MeasurementChannel.PMT2, address=2, relative_to_register=2, add=True)
    seq_gen.get_analog_high_result(MeasurementChannel.PMT2, address=3, relative_to_register=2, add=True)

    seq_gen.get_pulse_counter_result(MeasurementChannel.PMT1, address=4, relative_to_register=1, add=True, dword=True)
    seq_gen.get_pulse_counter_result(MeasurementChannel.PMT2, address=6, relative_to_register=2, add=True, dword=True)

    seq_gen.get_analog_low_result(MeasurementChannel.PD1REF, address=8, relative_to_register=0, add=True)
    seq_gen.get_analog_high_result(MeasurementChannel.PD1REF, address=9, relative_to_register=0, add=True)

    seq_gen.loop_end()
    ### Flash Loop End ###

    seq_gen.set_integrator_mode(pd1ref=IntegratorMode.FullReset, pmt1=IntegratorMode.FullReset, pmt2=IntegratorMode.FullReset)
    seq_gen.set_analog_control(pd1ref=AnalogControlMode.FullOffsetReset, pmt1=AnalogControlMode.FullOffsetReset, pmt2=AnalogControlMode.FullOffsetReset)

    seq_gen.reset_signal_outputs(SignalOutput.HVGatePMT1 | SignalOutput.HVGatePMT2)

    return seq_gen.sequence